{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Lf7huAiYp-An"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "cellView": "form",
        "id": "YHz2D-oIqBWa"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x44FFES-r6y0"
      },
      "source": [
        "# 텍스트 생성을 위한 Federated Learning"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iPFgLeZIsZ3Q"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://www.tensorflow.org/federated/tutorials/federated_learning_for_text_generation\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\">TensorFlow.org에서 보기</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ko/federated/tutorials/federated_learning_for_text_generation.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\">Google Colab에서 실행하기</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/ko/federated/tutorials/federated_learning_for_text_generation.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\">GitHub에서소스 보기</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KbNz2tuvsAFB"
      },
      "source": [
        "**참고**: 이 Colab은 <code>tensorflow_federated</code> pip 패키지의 <a>최신 릴리즈 버전</a>에서 동작하는 것으로 확인되었지만, Tensorflow Federated 프로젝트는 아직 릴리즈 전 개발 중이며 `master`에서 동작하지 않을 수 있습니다.\n",
        "\n",
        "이 튜토리얼은 [이미지 분류를 위한 Federated Learning](federated_learning_for_image_classification.ipynb) 튜토리얼의 개념을 기반으로 하며, 페더레이션 학습을 위한 몇 가지 유용한 접근 방식을 보여줍니다.\n",
        "\n",
        "특히, 이전에 훈련된 Keras 모델을 로드하고 분산된 (시뮬레이션) 데이터세트에 대한 페더레이션 훈련을 사용하여 구체화합니다. 이 방법은 여러 가지 이유로 실질적으로 중요합니다. 직렬화된 모델을 사용하는 기능을 통해 페더레이션 학습을 다른 ML 접근 방식과 쉽게 혼합할 수 있습니다. 또한, 이를 통해 사전 훈련된 모델의 범위가 증가할 수 있습니다. 예를 들어, 사전 훈련된 수많은 모델이 현재 널리 사용 가능하기 때문에 처음부터 언어 모델을 훈련하는 것은 거의 필요하지 않습니다(예: [TF Hub 참조](https://www.tensorflow.org/hub)). 대신, 사전 훈련된 모델에서 시작하여 특정 애플리케이션에 대한 분산 데이터의 특정 특성에 적응하여 Federated Learning을 사용하여 구체화하는 것이 더 합리적입니다.\n",
        "\n",
        "이 튜토리얼에서는 ASCII 문자를 생성하는 RNN으로 시작하고 페더레이션 학습을 통해 구체화합니다. 또한, 최종 가중치를 원래 Keras 모델로 피드백하여 표준 도구를 사용하여 쉽게 평가하고 텍스트를 생성할 수 있는 방법을 보여줍니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9LcC1AwjoqfR"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "!pip install --quiet --upgrade tensorflow_federated_nightly\n",
        "!pip install --quiet --upgrade nest_asyncio\n",
        "\n",
        "import nest_asyncio\n",
        "nest_asyncio.apply()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "ZjDQysatrc2S"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "b'Hello, World!'"
            ]
          },
          "execution_count": 3,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "import collections\n",
        "import functools\n",
        "import os\n",
        "import time\n",
        "\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_federated as tff\n",
        "\n",
        "np.random.seed(0)\n",
        "\n",
        "# Test the TFF is working:\n",
        "tff.federated_computation(lambda: 'Hello, World!')()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lyICXwVAxvW9"
      },
      "source": [
        "## 사전 훈련된 모델 로드하기\n",
        "\n",
        "TensorFlow 튜토리얼 [즉시 실행되는 RNN을 사용한 텍스트 생성](https://www.tensorflow.org/tutorials/sequences/text_generation)에 따라 사전 훈련된 모델을 로드합니다. 하지만, [셰익스피어의 전체 작품](http://www.gutenberg.org/files/100/100-0.txt)에 대한 훈련 대신 Charles Dickens의 [A Tale of Two Cities](http://www.ibiblio.org/pub/docs/books/gutenberg/9/98/98.txt) 및 [A Christmas Carol](http://www.ibiblio.org/pub/docs/books/gutenberg/4/46/46.txt)의 텍스트에 대해 모델을 사전 훈련했습니다.\n",
        "\n",
        "어휘 확장 이외에는 원래 튜토리얼을 수정하지 않았기 때문에 이 초기 모델은 최첨단이 아니지만, 합리적인 예측값을 생성하며, 튜토리얼 목적에는 충분합니다. 최종 모델은 `tf.keras.models.save_model(include_optimizer=False)`로 저장되었습니다.\n",
        "\n",
        "이 튜토리얼에서는 TFF에서 제공하는 데이터의 페더레이션 버전을 사용하여 셰익스피어에 대한 이 모델을 미세 조정하는 데 페더레이션 학습을 사용할 것입니다.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XgF8e2Ksyq1F"
      },
      "source": [
        "### 어휘 조회 테이블 생성하기"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "IlCgQBRVymwR"
      },
      "outputs": [],
      "source": [
        "# A fixed vocabularly of ASCII chars that occur in the works of Shakespeare and Dickens:\n",
        "vocab = list('dhlptx@DHLPTX $(,048cgkoswCGKOSW[_#\\'/37;?bfjnrvzBFJNRVZ\"&amp;*.26:\\naeimquyAEIMQUY]!%)-159\\r')\n",
        "\n",
        "# Creating a mapping from unique characters to indices\n",
        "char2idx = {u:i for i, u in enumerate(vocab)}\n",
        "idx2char = np.array(vocab)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2EH6MFRdzAwd"
      },
      "source": [
        "### 사전 훈련된 모델을 로드하고 일부 텍스트 생성하기"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "iIK674SrtCTm"
      },
      "outputs": [],
      "source": [
        "def load_model(batch_size):\n",
        "  urls = {\n",
        "      1: 'https://storage.googleapis.com/tff-models-public/dickens_rnn.batch1.kerasmodel',\n",
        "      8: 'https://storage.googleapis.com/tff-models-public/dickens_rnn.batch8.kerasmodel'}\n",
        "  assert batch_size in urls, 'batch_size must be in ' + str(urls.keys())\n",
        "  url = urls[batch_size]\n",
        "  local_file = tf.keras.utils.get_file(os.path.basename(url), origin=url)  \n",
        "  return tf.keras.models.load_model(local_file, compile=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "WvuwZBX5Ogfd"
      },
      "outputs": [],
      "source": [
        "def generate_text(model, start_string):\n",
        "  # From https://www.tensorflow.org/tutorials/sequences/text_generation\n",
        "  num_generate = 200\n",
        "  input_eval = [char2idx[s] for s in start_string]\n",
        "  input_eval = tf.expand_dims(input_eval, 0)\n",
        "  text_generated = []\n",
        "  temperature = 1.0\n",
        "\n",
        "  model.reset_states()\n",
        "  for i in range(num_generate):\n",
        "    predictions = model(input_eval)\n",
        "    predictions = tf.squeeze(predictions, 0)\n",
        "    predictions = predictions / temperature\n",
        "    predicted_id = tf.random.categorical(\n",
        "        predictions, num_samples=1)[-1, 0].numpy()\n",
        "    input_eval = tf.expand_dims([predicted_id], 0)\n",
        "    text_generated.append(idx2char[predicted_id])\n",
        "\n",
        "  return (start_string + ''.join(text_generated))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "MGAdStJ5wDPV"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Downloading data from https://storage.googleapis.com/tff-models-public/dickens_rnn.batch1.kerasmodel\n",
            "16195584/16193984 [==============================] - 0s 0us/step\n",
            "16203776/16193984 [==============================] - 0s 0us/step\n",
            "What of TensorFlow Federated, you ask? Sall\n",
            "yesterday. Received the Bailey.\"\n",
            "\n",
            "\"Mr. Lorry, grimmering himself, or low varked thends the winter, and the eyes of Monsieur\n",
            "Defarge. \"Let his mind, hon in his\n",
            "life and message; four declare\n"
          ]
        }
      ],
      "source": [
        "# Text generation requires a batch_size=1 model.\n",
        "keras_model_batch1 = load_model(batch_size=1)\n",
        "print(generate_text(keras_model_batch1, 'What of TensorFlow Federated, you ask? '))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kKMUn-TlgxuP"
      },
      "source": [
        "## 페더레이션 셰익스피어 데이터 로드 및 전처리\n",
        "\n",
        "`tff.simulation.datasets` 패키지는 \"clients\"로 분할된 다양한 데이터세트를 제공합니다. 여기서 각 클라이언트는 페더레이션 학습에 참여할 수 있는 특정 기기의 데이터세트에 해당합니다.\n",
        "\n",
        "이들 데이터세트는 실제 분산된 데이터에 대한 훈련 문제를 시뮬레이션에서 복제하는 현실적인 비 IID 데이터 분산을 제공합니다. 이 데이터의 일부 전처리는 [Leaf 프로젝트](https://arxiv.org/abs/1812.01097)([github](https://github.com/TalwalkarLab/leaf))의 도구를 사용하여 수행되었습니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "di3nStTDg0qc"
      },
      "outputs": [],
      "source": [
        "train_data, test_data = tff.simulation.datasets.shakespeare.load_data()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_iiY65Vv4QNK"
      },
      "source": [
        "`shakespeare.load_data()`에서 제공하는 데이터세트는 셰익스피어 연극의 특정 캐릭터가 말한 각 대사에 하나씩, 문자열 `Tensors`의 시퀀스로 구성됩니다. 클라이언트 키는 캐릭터의 이름과 결합된 연극의 이름으로 구성됩니다. 예를 들어, `MUCH_ADO_ABOUT_NOTHING_OTHELLO`는 연극 *Much Ado About Nothing*의 오델로 캐릭터 대사에 해당합니다. 실제 페더레이션 학습 시나리오에서 클라이언트는 ID로 식별되거나 추적되지 않지만, 시뮬레이션의 경우 키가 지정된 데이터세트로 작업하는 것이 유용합니다.\n",
        "\n",
        "예를 들어, 여기에서 King Lear의 일부 데이터를 볼 수 있습니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "FEKiy1ntmmnk"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "tf.Tensor(b'', shape=(), dtype=string)\n",
            "tf.Tensor(b'What?', shape=(), dtype=string)\n"
          ]
        }
      ],
      "source": [
        "# Here the play is \"The Tragedy of King Lear\" and the character is \"King\".\n",
        "raw_example_dataset = train_data.create_tf_dataset_for_client(\n",
        "    'THE_TRAGEDY_OF_KING_LEAR_KING')\n",
        "# To allow for future extensions, each entry x\n",
        "# is an OrderedDict with a single key 'snippets' which contains the text.\n",
        "for x in raw_example_dataset.take(2):\n",
        "  print(x['snippets'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kUnbI5Hp4sXg"
      },
      "source": [
        "이제 `tf.data.Dataset` 변환을 사용하여 위에 로드된 문자 RNN을 훈련하기 위한 이 데이터를 준비합니다.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "9kDkmGe-7No7"
      },
      "outputs": [],
      "source": [
        "# Input pre-processing parameters\n",
        "SEQ_LENGTH = 100\n",
        "BATCH_SIZE = 8\n",
        "BUFFER_SIZE = 100  # For dataset shuffling"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "W95Of6Bwsrfc"
      },
      "outputs": [],
      "source": [
        "# Construct a lookup table to map string chars to indexes,\n",
        "# using the vocab loaded above:\n",
        "table = tf.lookup.StaticHashTable(\n",
        "    tf.lookup.KeyValueTensorInitializer(\n",
        "        keys=vocab, values=tf.constant(list(range(len(vocab))),\n",
        "                                       dtype=tf.int64)),\n",
        "    default_value=0)\n",
        "\n",
        "\n",
        "def to_ids(x):\n",
        "  s = tf.reshape(x['snippets'], shape=[1])\n",
        "  chars = tf.strings.bytes_split(s).values\n",
        "  ids = table.lookup(chars)\n",
        "  return ids\n",
        "\n",
        "\n",
        "def split_input_target(chunk):\n",
        "  input_text = tf.map_fn(lambda x: x[:-1], chunk)\n",
        "  target_text = tf.map_fn(lambda x: x[1:], chunk)\n",
        "  return (input_text, target_text)\n",
        "\n",
        "\n",
        "def preprocess(dataset):\n",
        "  return (\n",
        "      # Map ASCII chars to int64 indexes using the vocab\n",
        "      dataset.map(to_ids)\n",
        "      # Split into individual chars\n",
        "      .unbatch()\n",
        "      # Form example sequences of SEQ_LENGTH +1\n",
        "      .batch(SEQ_LENGTH + 1, drop_remainder=True)\n",
        "      # Shuffle and form minibatches\n",
        "      .shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)\n",
        "      # And finally split into (input, target) tuples,\n",
        "      # each of length SEQ_LENGTH.\n",
        "      .map(split_input_target))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Jw98HnKmEhuh"
      },
      "source": [
        "원래 시퀀스의 형성과 위의 배치 형성에서는 단순성을 위해 `drop_remainder=True`를 사용합니다. 즉, 최소한 `(SEQ_LENGTH + 1) * BATCH_SIZE` 문자가 없는 모든 문자(클라이언트)는 빈 데이터세트를 갖게 됩니다. 이를 해결하기 위한 일반적인 접근 방식은 배치를 특수 토큰으로 채운 다음 해당 토큰을 고려하지 않도록 손실을 마스크하는 것입니다.\n",
        "\n",
        "이로 인해 예제가 다소 복잡해지므로 이 튜토리얼에서는 [표준 튜토리얼](https://www.tensorflow.org/tutorials/sequences/text_generation)에서와 같이 전체 배치만 사용합니다. 그러나 페더레이션 설정에서는 많은 사용자가 작은 데이터세트를 가질 수 있으므로 이 문제가 더 중요합니다.\n",
        "\n",
        "이제 `raw_example_dataset`를 전처리하고 유형을 확인할 수 있습니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "7rTal7bksWwc"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "(TensorSpec(shape=(8, 100), dtype=tf.int64, name=None), TensorSpec(shape=(8, 100), dtype=tf.int64, name=None))\n"
          ]
        }
      ],
      "source": [
        "example_dataset = preprocess(raw_example_dataset)\n",
        "print(example_dataset.element_spec)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ePT8Oawm8SRP"
      },
      "source": [
        "## 모델 컴파일 및 전처리된 데이터로 테스트하기"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vEgDsz-48cAq"
      },
      "source": [
        "컴파일되지 않은 keras 모델을 로드했지만, `keras_model.evaluate`를 실행하려면 손실 및 메트릭을 사용하여 컴파일해야 합니다. 또한, 페더레이션 학습에서 기기 내 옵티마이저로 사용될 옵티마이저에서 컴파일할 것입니다."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RsuVZ5KMWnn8"
      },
      "source": [
        "원래 튜토리얼에는 문자 수준의 정확성은 없었습니다(올바른 다음 문자에 가장 높은 확률을 부여한 예측 비율). 이 정확성은 유용한 메트릭이므로 추가합니다. 그러나, 예측값은 순위 3(각 `BATCH_SIZE * SEQ_LENGTH` 예측값에 대한 로짓 벡터)이고 `SparseCategoricalAccuracy`는 순위 2 예측값만 기대하므로 정확성에 대한 새 메트릭 클래스를 정의해야 합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "gOUiDBvmWlM9"
      },
      "outputs": [],
      "source": [
        "class FlattenedCategoricalAccuracy(tf.keras.metrics.SparseCategoricalAccuracy):\n",
        "\n",
        "  def __init__(self, name='accuracy', dtype=tf.float32):\n",
        "    super().__init__(name, dtype=dtype)\n",
        "\n",
        "  def update_state(self, y_true, y_pred, sample_weight=None):\n",
        "    y_true = tf.reshape(y_true, [-1, 1])\n",
        "    y_pred = tf.reshape(y_pred, [-1, len(vocab), 1])\n",
        "    return super().update_state(y_true, y_pred, sample_weight)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "U2X9eFgt94PM"
      },
      "source": [
        "이제 모델을 컴파일하고 `example_dataset`에서 평가할 수 있습니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "c3Xd-52-9zGa"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Downloading data from https://storage.googleapis.com/tff-models-public/dickens_rnn.batch8.kerasmodel\n",
            "16195584/16193984 [==============================] - 0s 0us/step\n",
            "16203776/16193984 [==============================] - 0s 0us/step\n",
            "Evaluating on an example Shakespeare character: 0.402000\n",
            "Expected accuracy for random guessing: 0.012\n",
            "Evaluating on completely random data: 0.011\n"
          ]
        }
      ],
      "source": [
        "BATCH_SIZE = 8  # The training and eval batch size for the rest of this tutorial.\n",
        "keras_model = load_model(batch_size=BATCH_SIZE)\n",
        "keras_model.compile(\n",
        "    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "    metrics=[FlattenedCategoricalAccuracy()])\n",
        "\n",
        "# Confirm that loss is much lower on Shakespeare than on random data\n",
        "loss, accuracy = keras_model.evaluate(example_dataset.take(5), verbose=0)\n",
        "print(\n",
        "    'Evaluating on an example Shakespeare character: {a:3f}'.format(a=accuracy))\n",
        "\n",
        "# As a sanity check, we can construct some completely random data, where we expect\n",
        "# the accuracy to be essentially random:\n",
        "random_guessed_accuracy = 1.0 / len(vocab)\n",
        "print('Expected accuracy for random guessing: {a:.3f}'.format(\n",
        "    a=random_guessed_accuracy))\n",
        "random_indexes = np.random.randint(\n",
        "    low=0, high=len(vocab), size=1 * BATCH_SIZE * (SEQ_LENGTH + 1))\n",
        "data = collections.OrderedDict(\n",
        "    snippets=tf.constant(\n",
        "        ''.join(np.array(vocab)[random_indexes]), shape=[1, 1]))\n",
        "random_dataset = preprocess(tf.data.Dataset.from_tensor_slices(data))\n",
        "loss, accuracy = keras_model.evaluate(random_dataset, steps=10, verbose=0)\n",
        "print('Evaluating on completely random data: {a:.3f}'.format(a=accuracy))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lH0WzL5L8Lm4"
      },
      "source": [
        "## 페더레이션 학습으로 모델 미세 조정하기"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NCao4M3L_tsA"
      },
      "source": [
        "TFF는 모든 TensorFlow 계산을 직렬화하여 잠재적으로 Python이 아닌 환경에서 실행할 수 있습니다(현재로서는 Python으로 구현된 시뮬레이션 런타임만 사용할 수 있음). 즉시 모드(TF 2.0)에서 실행 중이지만, 현재 TFF는 \"`with tf.Graph.as_default()`\" 문의 컨텍스트 내에서 필요한 ops를 구성하여 TensorFlow 계산을 직렬화합니다. 따라서 TFF가 제어하는 ​​그래프에 모델을 도입하는 데 사용할 수 있는 함수를 제공해야 합니다. 다음과 같이 수행합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "id": "5KadIvFp7m6y"
      },
      "outputs": [],
      "source": [
        "# Clone the keras_model inside `create_tff_model()`, which TFF will\n",
        "# call to produce a new copy of the model inside the graph that it will \n",
        "# serialize. Note: we want to construct all the necessary objects we'll need \n",
        "# _inside_ this method.\n",
        "def create_tff_model():\n",
        "  # TFF uses an `input_spec` so it knows the types and shapes\n",
        "  # that your model expects.\n",
        "  input_spec = example_dataset.element_spec\n",
        "  keras_model_clone = tf.keras.models.clone_model(keras_model)\n",
        "  return tff.learning.from_keras_model(\n",
        "      keras_model_clone,\n",
        "      input_spec=input_spec,\n",
        "      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "      metrics=[FlattenedCategoricalAccuracy()])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZJF_yhJxAi2l"
      },
      "source": [
        "이제 모델을 개선하는 데 사용할 Federated Averaging 반복 프로세스를 구성할 준비가 되었습니다(Federated Averaging 알고리즘에 대한 자세한 내용은 분산 데이터에서 딥 네트워크의 [Communication-Efficient Learning](https://arxiv.org/abs/1602.05629) 논문 참조).\n",
        "\n",
        "컴파일된 Keras 모델을 사용하여 페더레이션 훈련의 각 라운드 후에 표준 (비 페더레이션) 평가를 수행합니다. 이는 시뮬레이션된 페더레이션 학습을 수행할 때 연구 목적으로 유용하며, 표준 테스트데이터 세트가 있습니다.\n",
        "\n",
        "현실적인 운영 환경에서 같은 기술을 사용하여 페러레이션 학습으로 훈련 된 모델을 테스트 또는 품질 보증 목적으로 중앙 집중식 벤치마크 데이터세트에서 평가할 수 있습니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "my3PW3qhAMDA"
      },
      "outputs": [],
      "source": [
        "# This command builds all the TensorFlow graphs and serializes them: \n",
        "fed_avg = tff.learning.build_federated_averaging_process(\n",
        "    model_fn=create_tff_model,\n",
        "    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(lr=0.5))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qVOkzs9C9kmv"
      },
      "source": [
        "다음은 단일 배치의 단일 클라이언트에서 한 라운드에 대해 페더레이션 평균화를 실행하는 가장 간단한 루프입니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "id": "lrjUrkjq9jYk"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "loss=4.403, accuracy=0.132\n"
          ]
        }
      ],
      "source": [
        "state = fed_avg.initialize()\n",
        "state, metrics = fed_avg.next(state, [example_dataset.take(5)])\n",
        "print('loss={l:.3f}, accuracy={a:.3f}'.format(\n",
        "    l=metrics.train.loss, a=metrics.train.accuracy))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o2CjvVg0FZpS"
      },
      "source": [
        "이제 약간 더 흥미로운 훈련 및 평가 루프를 작성해 보겠습니다.\n",
        "\n",
        "이 시뮬레이션이 여전히 상대적으로 빠르게 실행되도록 각 라운드에 대해 2개의 미니 배치만을 고려하여 같은 3개의 클라이언트에 대해 훈련합니다.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "id": "wE386-rbMCve"
      },
      "outputs": [],
      "source": [
        "def data(client, source=train_data):\n",
        "  return preprocess(source.create_tf_dataset_for_client(client)).take(5)\n",
        "\n",
        "\n",
        "clients = [\n",
        "    'ALL_S_WELL_THAT_ENDS_WELL_CELIA', 'MUCH_ADO_ABOUT_NOTHING_OTHELLO',\n",
        "]\n",
        "\n",
        "train_datasets = [data(client) for client in clients]\n",
        "\n",
        "# We concatenate the test datasets for evaluation with Keras by creating a \n",
        "# Dataset of Datasets, and then identity flat mapping across all the examples.\n",
        "test_dataset = tf.data.Dataset.from_tensor_slices(\n",
        "    [data(client, test_data) for client in clients]).flat_map(lambda x: x)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cU3FuY00MOoX"
      },
      "source": [
        "`clone_model()`은 가중치를 복제하지 않으므로 `fed_avg.initialize()`에 의해 생성된 모델의 초기 상태는 로드된 가중치가 아니라 Keras 모델의 임의 이니셜라이저를 기반으로 합니다. 사전 훈련된 모델에서 훈련을 시작하기 위해 로드된 모델에서 직접 서버 상태의 모델 가중치를 설정합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "metadata": {
        "id": "vm_-PU8OFXpY"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Round 0\n",
            "\tEval: loss=3.324, accuracy=0.401\n",
            "\tTrain: loss=4.360, accuracy=0.155\n",
            "Round 1\n",
            "\tEval: loss=4.361, accuracy=0.049\n",
            "\tTrain: loss=4.235, accuracy=0.164\n",
            "Round 2\n",
            "\tEval: loss=4.219, accuracy=0.177\n",
            "\tTrain: loss=4.081, accuracy=0.221\n",
            "Round 3\n",
            "\tEval: loss=4.080, accuracy=0.174\n",
            "\tTrain: loss=3.940, accuracy=0.226\n",
            "Round 4\n",
            "\tEval: loss=3.991, accuracy=0.176\n",
            "\tTrain: loss=3.840, accuracy=0.226\n",
            "Final evaluation\n",
            "\tEval: loss=3.909, accuracy=0.171\n"
          ]
        }
      ],
      "source": [
        "NUM_ROUNDS = 5\n",
        "\n",
        "# The state of the FL server, containing the model and optimization state.\n",
        "state = fed_avg.initialize()\n",
        "\n",
        "# Load our pre-trained Keras model weights into the global model state.\n",
        "state = tff.learning.state_with_new_model_weights(\n",
        "    state,\n",
        "    trainable_weights=[v.numpy() for v in keras_model.trainable_weights],\n",
        "    non_trainable_weights=[\n",
        "        v.numpy() for v in keras_model.non_trainable_weights\n",
        "    ])\n",
        "\n",
        "\n",
        "def keras_evaluate(state, round_num):\n",
        "  # Take our global model weights and push them back into a Keras model to\n",
        "  # use its standard `.evaluate()` method.\n",
        "  keras_model = load_model(batch_size=BATCH_SIZE)\n",
        "  keras_model.compile(\n",
        "      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "      metrics=[FlattenedCategoricalAccuracy()])\n",
        "  state.model.assign_weights_to(keras_model)\n",
        "  loss, accuracy = keras_model.evaluate(example_dataset, steps=2, verbose=0)\n",
        "  print('\\tEval: loss={l:.3f}, accuracy={a:.3f}'.format(l=loss, a=accuracy))\n",
        "\n",
        "\n",
        "for round_num in range(NUM_ROUNDS):\n",
        "  print('Round {r}'.format(r=round_num))\n",
        "  keras_evaluate(state, round_num)\n",
        "  state, metrics = fed_avg.next(state, train_datasets)\n",
        "  train_metrics = metrics['train']\n",
        "  print('\\tTrain: loss={l:.3f}, accuracy={a:.3f}'.format(\n",
        "      l=train_metrics['loss'], a=train_metrics['accuracy']))\n",
        "\n",
        "print('Final evaluation')\n",
        "keras_evaluate(state, NUM_ROUNDS + 1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SoshvcHhXVa6"
      },
      "source": [
        "기본 변경으로 큰 차이를 만들 수 있는 충분한 훈련을 하지 않았지만, 더 많은 셰익스피어 데이터에 대해 더 오래 훈련하면 업데이트된 모델로 생성된 텍스트 스타일에서 차이를 볼 수 있습니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "metadata": {
        "id": "NTUig7QmXavy"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "What of TensorFlow Federated, you ask? Shalways, I will call your\n",
            "compet with any city brought their faces uncompany,\" besumed him. \"When he\n",
            "sticked Madame Defarge pushed the lamps.\n",
            "\n",
            "\"Have I often but no unison. She had probably come,\n"
          ]
        }
      ],
      "source": [
        "# Set our newly trained weights back in the originally created model.\n",
        "keras_model_batch1.set_weights([v.numpy() for v in keras_model.weights])\n",
        "# Text generation requires batch_size=1\n",
        "print(generate_text(keras_model_batch1, 'What of TensorFlow Federated, you ask? '))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4DA1Fkf5mN0s"
      },
      "source": [
        "## 확장 제안\n",
        "\n",
        "이 튜토리얼은 첫 단계에 불과합니다! 이 노트북을 확장하는 방법에 대한 몇 가지 아이디어는 다음과 같습니다.\n",
        "\n",
        "- 무작위로 훈련할 클라이언트를 샘플링하는 보다 현실적인 훈련 루프를 작성합니다.\n",
        "- 클라이언트 데이터세트에서 \"`.repeat(NUM_EPOCHS)`\"를 사용하여 여러 epoch의 로컬 학습을 시도합니다(예: [McMahan et. al.](https://arxiv.org/abs/1602.05629)). 이를 수행하는 [이미지 분류를 위한 Federated Learning](federated_learning_for_image_classification.ipynb)도 참조하세요.\n",
        "- `compile()` 명령을 변경하여 클라이언트에서 다른 최적화 알고리즘을 사용하여 실험합니다.\n",
        "- `build_federated_averaging_process`에 `server_optimizer` 인수를 사용하여 서버에 모델 업데이트를 적용하기 위한 다른 알고리즘을 시도합니다.\n",
        "- `client_weight_fn` 인수를 `build_federated_averaging_process`에 사용하여 클라이언트의 다른 가중치를 시도합니다. 기본 가중치는 클라이언트의 예제 수에 따라 업데이트되지만, 예를 들어 `client_weight_fn=lambda _: tf.constant(1.0)`를 수행할 수 있습니다."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "federated_learning_for_text_generation.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
